%% Two-sided heterogeneity and trade : Quantification

% Load data - firm level
% Reads the firm-by-country calibration matrices and the list of country
% names from workdir, then derives the sample dimensions K and firms.
clear
workdir = 'two_sided/matlab';
path(workdir,path);   % put workdir first on the search path

% File-name suffix selecting the sample; every input file is named
% <stem><manuf>.raw. For the whole-economy analysis set manuf = '' rather
% than commenting the line out: after `clear` the variable would be
% undefined and the strcat calls below would fail. With manuf = '' the
% `if manuf` branches in this script are skipped.
% NOTE(review): presumably unsuffixed input files exist for that case - confirm.
manuf = '_manuf2';

L = csvread(strcat(workdir,'/calib_firmL',manuf,'.raw'));           % firms x countries
Cust = csvread(strcat(workdir,'/calib_firmCustid',manuf,'.raw'));           % firms x countries
R = csvread(strcat(workdir,'/calib_firmR',manuf,'.raw'));           % firms x countries
% NOTE: this variable shadows the built-in constant pi for the rest of the script.
pi = csvread(strcat(workdir,'/calib_firmpi',manuf,'.raw'));           % firms x countries

% Country names: comma-delimited strings, read into a cell array.
ctyfile = strcat(workdir,'/countries',manuf,'.raw');
fid = fopen(ctyfile);
if fid<0
    % fopen returns -1 on failure; stop here with a readable message instead
    % of letting textscan fail on an invalid file identifier.
    error('two_sided:cannotOpenFile','Cannot open country file: %s',ctyfile);
end
N = textscan(fid, '%s','delimiter',',');
fclose(fid);
cty = N{1};

K = size(L,2); % countries
firms = size(L,1);

% Share of purchases from Norway, one entry per firm.
% With a sample suffix set, column 1 of the calib_firmpiNO file is used as
% that share and column 2 is stored as revenue; without a suffix the share
% is fixed at 0.8 for every firm and revenue is left undefined.
if manuf
    SS = csvread(strcat(workdir,'/calib_firmpiNO',manuf,'.raw'));   % firms x (at least) 2
    piNO = SS(:,1);
    revenue = SS(:,2);
else
    piNO = .8*ones(firms,1);
end

% Note: pi is the import share, excluding purchases from Norway. piNO is
% the share of purchases from Norway. Add Norway as the first column, and
% rescale the import shares by (1-piNO).
pp = piNO(:,ones(1,K));            % firms x K, piNO copied into every column
pi2 = horzcat(piNO,(1-pp).*pi);    % firms x (K+1)


%
% Fixed point
%
% Iterates on rhohat (K x 1, one entry per country) until the squared
% distance between successive iterates is no larger than eps, or until
% maxit iterations have been run.
rhohat = ones(K,1);
rhohat_new2 = zeros(K,1);   % preallocated; every entry is overwritten each pass
importers = pi>0;           % firms x K, firms with positive imports (loop-invariant)
maxit = 10000;              % safeguard against a non-terminating loop
err = 1;
count = 0;
while err>eps && count<maxit
    % Normalization: rhohat=1 for Norway (first column of pi2).
    rr = repmat([1;rhohat]',firms,1);   % firms x (K+1)

    Omega_g = sum(rr.*pi2,2);  % firms x 1
    rhohat_new = repmat(Omega_g,1,K).*R;  % firms x K

    % Take the median across firms that have positive imports
    % (the mean over the same firms is the alternative aggregator).
    for j=1:K
        col = rhohat_new(:,j);
        rhohat_new2(j,1) = median(col(importers(:,j)));
    end

    err = (rhohat_new2-rhohat)'*(rhohat_new2-rhohat);
    count = count + 1;
    rhohat = rhohat_new2;
end

fprintf('Fixed point: %d iterations, err = %g\n',count,err);
if any(isnan(rhohat))
    % median of an empty selection is NaN; NaN>eps is false, so the loop
    % above stops as if it had converged. Flag it instead of staying silent.
    warning('two_sided:fixedPointNaN', ...
        'rhohat contains NaN (a country without importing firms, or NaN in the data); the fixed point is not valid.');
elseif err>eps
    warning('two_sided:fixedPointNotConverged', ...
        'Fixed point did not converge within %d iterations (err = %g).',maxit,err);
end


%% Fit of the model

close all;
savefigures=0;   % set to a nonzero value to write the figures under <workdir>/graph

% Summary statistics of log(Omega_g) across firms
lnOmega = log(Omega_g);   % firms x 1, computed once and reused below
disp('Median Omega_g'); disp(median(lnOmega));
disp('Mean Omega_g'); disp(mean(lnOmega));
disp('Stdev Omega_g'); disp(std(lnOmega));
disp('1st decile Omega_g'); disp(prctile(lnOmega,10));

% Weighted average, using revenue as weights
if manuf
    revenue(revenue<=0)=1;   % non-positive revenue is replaced by 1 before normalizing
    weights = revenue./sum(revenue);   % firms x 1, sums to one
    disp('Weighted Omega_g'); disp(sum(lnOmega.*weights));
end

% Histogram of log(Omega_g), 100 bins
figure(1); hist(lnOmega,100); xlabel('lnOmega^\gamma','FontSize',12)
if savefigures
    saveas(gcf,strcat(workdir,'/graph/omega_g',manuf,'.eps'));
    saveas(gcf,strcat(workdir,'/graph/omega_g',manuf,'.pdf'));
end

% Model-implied matrix: entry (i,j) is rhohat(j)/Omega_g(i) for firm i and
% country j. The same matrix is used as the model counterpart of both R and L.
Rhatm = repmat(rhohat',firms,1)./repmat(Omega_g,1,K);   % firms x K
Lhatm = Rhatm;

% Relationship between change in # connections in model/data
active = L>0;      % firm-country pairs with positive L
multi = Cust>1;    % firm-country pairs with Cust above one

% Logs over each selection, taken once and reused in the displays below
lnR_act = log(R(active));
lnL_act = log(L(active));
lnM_act = log(Rhatm(active));
lnL_multi = log(L(multi));
lnM_multi = log(Rhatm(multi));

disp('Median log R data/model');
disp([median(lnR_act) median(lnM_act)]);
disp('Mean log R data/model');
disp([mean(lnR_act) mean(lnM_act)]);
disp('Std log R data/model');
disp([std(lnR_act) std(lnM_act)]);

disp('Median L data/model');
disp([median(lnL_act) median(lnM_act)]);
disp('Mean L data/model');
disp([mean(lnL_act) mean(lnM_act)]);
disp('Std L data/model');
disp([std(lnL_act) std(lnM_act)]);

disp('For firms with >1 customer');
disp('...Median L data/model');
disp([median(lnL_multi) median(lnM_multi)]);
disp('...Mean L data/model');
disp([mean(lnL_multi) mean(lnM_multi)]);
disp('...Std L data/model');
disp([std(lnL_multi) std(lnM_multi)]);

% Weighted average, using revenue as weights
if manuf
    ww = repmat(weights,1,K);   % firm weight copied into every country column
    w_act = ww(active);
    w_multi = ww(multi);
    ss = sum(w_act);            % total weight over the active selection
    ss2 = sum(w_multi);         % total weight over the Cust>1 selection
    disp('Weighted mean R data');
    disp(sum(lnR_act.*w_act)/ss);
    disp('Weighted mean L data');
    disp(sum(lnL_act.*w_act)/ss);
    disp('Weighted mean L data, L>1');
    disp(sum(lnL_multi.*w_multi)/ss2);
    disp('Weighted mean R model');
    disp(sum(lnM_act.*w_act)/ss);
    disp('Weighted mean L model');
    disp(sum(lnM_act.*w_act)/ss);
    disp('Weighted mean L model, L>1');
    disp(sum(lnM_multi.*w_multi)/ss2);
end

% By country
% For each country, median and mean of log L, log R (data) and log Rhatm
% (model) over the firms with positive L in that country. The loop does not
% use the revenue weights, so it also runs when the `if manuf` branches
% (which create them) are skipped.
Ljmedian = zeros(K,1);  Ljmean = zeros(K,1);
Rjmedian = zeros(K,1);  Rjmean = zeros(K,1);
Rmjmedian = zeros(K,1); Rmjmean = zeros(K,1);
for j=1:K
    active2 = L(:,j)>0;             % firms with positive L in country j
    Lj = log(L(active2,j));
    Rj = log(R(active2,j));
    Rmj = log(Rhatm(active2,j));
    Ljmedian(j,1) = median(Lj);   % data
    Ljmean(j,1) = mean(Lj);       % data
    Rjmedian(j,1) = median(Rj);   % data
    Rjmean(j,1) = mean(Rj);       % data
    Rmjmedian(j,1) = median(Rmj); % model
    Rmjmean(j,1) = mean(Rmj);     % model
end

% Country-level fit: mean log L in the data (x) against mean log Rhatm in
% the model (y). Markers are drawn in white, so each country shows up only
% through its text label; the dotted line is the 45-degree line.
figure(2);
xmin=-.3; xmax=.2;   % common axis limits for both axes
scatter(Ljmean,Rmjmean,'w');
for j=1:K
    text(Ljmean(j),Rmjmean(j),cty(j));
end
xlabel('Data','FontSize',12); ylabel('Model','FontSize',12);
hold on; plot(linspace(xmin,xmax,30),linspace(xmin,xmax,30),':');
axis([xmin xmax xmin xmax]);

if savefigures
    saveas(gcf,strcat(workdir,'/graph/model_fit',manuf,'.eps'));
    saveas(gcf,strcat(workdir,'/graph/model_fit',manuf,'.pdf'));
end

% Cross-country correlations between data and model moments.
% Labels are printed with disp, like everywhere else in this script.
disp('Correlation R median data/model'); disp(corr(Rjmedian,Rmjmedian));
disp('Correlation L median data/model'); disp(corr(Ljmedian,Rmjmedian));

disp('Correlation R mean data/model'); disp(corr(Rjmean,Rmjmean));
disp('Correlation L mean data/model'); disp(corr(Ljmean,Rmjmean));


